{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Query Rewriting (Azure AI Search)\n",
    "\n",
    "This code demonstrates how to use Azure AI Search with advanced query rewriting to improve the relevance of your search results. The code performs the following tasks:\n",
    "\n",
    "+ Create an index schema\n",
    "+ Load the sample data from a local folder\n",
    "+ Embed the documents in-memory using Azure OpenAI's text-embedding-ada-002 model\n",
    "+ Index the vector and nonvector fields on Azure AI Search\n",
    "+ Rewrite a sample question to improve the relevance of the result documents\n",
    "+ Manually combine the results of multiple rewritten queries using [Reciprocal Rank Fusion (RRF)](https://learn.microsoft.com/azure/search/hybrid-search-ranking).\n",
    "+ Use [simple query syntax](https://learn.microsoft.com/azure/search/query-simple-syntax) and [multi-vector queries](https://learn.microsoft.com/azure/search/vector-search-how-to-query?tabs=query-2023-11-01%2Cfilter-2023-11-01#multiple-vector-queries) to automatically combine multiple rewritten queries using built-in RRF\n",
    "\n",
    "The code uses Azure OpenAI to generate embeddings for title and content fields. You'll need access to Azure OpenAI to run this demo.\n",
    "\n",
    "The code reads the `text-sample.json` file, which contains the input data for which embeddings need to be generated.\n",
    "\n",
    "The output is a combination of human-readable text and embeddings that can be pushed into a search index.\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "- An Azure subscription, with [access to Azure OpenAI](https://aka.ms/oai/access). This sample uses two models.\n",
    "\n",
    "  - Specify [2023-12-01-preview REST API](https://learn.microsoft.com/azure/ai-services/openai/reference) or later when providing an Azure OpenAI endpoint.\n",
    "\n",
    "  - Specify a deployment of the `text-embedding-3-large` embedding model. As a naming convention, we name deployments after the model name: \"text-embedding-3-large\".\n",
    "  \n",
    "  - Specify a deployment of a chat model, such as gpt-4o or gpt-4o-mini. This example uses structured outputs to return a valid JSON object, which requires a specific version of a chat model.\n",
    "  \n",
    "    - [Review supported models](https://learn.microsoft.com/azure/ai-services/openai/how-to/json-mode?tabs=python#supported-models) for chat models supporting JSON mode. Note the model version number. If you already have a deployment, verify the model version is listed as a supported model.\n",
    "  \n",
    "    - [Check regional availability](https://learn.microsoft.com/azure/ai-services/openai/concepts/models#standard-deployment-model-availability) of the chat models. Make sure your Azure OpenAI resource is in a region that supports the model.\n",
    "\n",
    "- Azure AI Search, any tier and region, but you must have Basic or higher to try the semantic ranker. This example creates an index. Check your index quota to make sure you have room. [Enable semantic ranking](https://learn.microsoft.com/azure/search/semantic-how-to-enable-disable) before running the hybrid query with semantic ranking.\n",
    "\n",
    "We used Python 3.11, [Visual Studio Code with the Python extension](https://code.visualstudio.com/docs/python/python-tutorial), and the [Jupyter extension](https://marketplace.visualstudio.com/items?itemName=ms-toolsai.jupyter) to test this example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up a Python virtual environment in Visual Studio Code\n",
    "\n",
    "1. Open the Command Palette (Ctrl+Shift+P).\n",
    "1. Search for **Python: Create Environment**.\n",
    "1. Select **Venv**.\n",
    "1. Select a Python interpreter. Choose 3.10 or later.\n",
    "\n",
    "It can take a minute to set up. If you run into problems, see [Python environments in VS Code](https://code.visualstudio.com/docs/python/environments)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -r query-rewrite-requirements.txt --quiet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up your environment variables\n",
    "\n",
    "The demo-python folder contains a `.env-sample` file that you can modify for your environment variables.\n",
    "\n",
    "Remember to omit API keys if you're using Azure role-based permissions. On Azure AI Search, you should have Search Service Contributor, Search Index Data Contributor, and Search Index Data Reader permissions. On Azure OpenAI, you should have Cognitive Services Contributor permissions.\n",
    "\n",
    "For this notebook, provide environment variables for endpoints and deployed models. \n",
    "\n",
    "Use the `.env` file in the parent `demo-python/code` folder or create a separate `.env` file in the `semantic-ranker-query-rewrite` sample folder.\n",
    "\n",
    "| Variable Name | Actual, suggested, or placeholder value |\n",
    "|---------------|---------------------------\n",
    "| AZURE_SEARCH_SERVICE_ENDPOINT | PLACEHOLDER FOR YOUR SEARCH SERVICE ENDPOINT |\n",
    "| AZURE_SEARCH_INDEX | PLACEHOLDER FOR AN INDEX NAME | \n",
    "| AZURE_SEARCH_ADMIN_KEY | Omit the key if you're using role-based access controls. | \n",
    "| AZURE_OPENAI_ENDPOINT  | PLACEHOLDER FOR YOUR AZURE OPENAI ENDPOINT |\n",
    "| AZURE_OPENAI_KEY=  | Omit the key if you're using role-based access controls. | \n",
    "| AZURE_OPENAI_API_VERSION | `2024-10-21` 2024-07-18 and later is required for JSON mode. |\n",
    "| AZURE_OPENAI_EMBEDDING_DEPLOYMENT | `text-embedding-3-large` or any embedding model on Azure OpenAI |\n",
    "| AZURE_OPENAI_CHATGPT_DEPLOYMENT | `gpt-4o-mini` or any chat model on Azure OpenAI. Remember to check model version and regional availability. |\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import required libraries and environment variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dotenv import load_dotenv\n",
    "from azure.identity import DefaultAzureCredential\n",
    "from azure.core.credentials import AzureKeyCredential\n",
    "import os\n",
    "\n",
    "load_dotenv(override=True) # take environment variables from .env.\n",
    "\n",
    "# The following variables from your .env file are used in this notebook\n",
    "endpoint = os.environ[\"AZURE_SEARCH_SERVICE_ENDPOINT\"]\n",
    "admin_key = os.getenv(\"AZURE_SEARCH_ADMIN_KEY\")\n",
    "credential = DefaultAzureCredential() if not admin_key else AzureKeyCredential(admin_key)\n",
    "index_name = os.getenv(\"AZURE_SEARCH_INDEX\", \"qr-example\")\n",
    "azure_openai_endpoint = os.environ[\"AZURE_OPENAI_ENDPOINT\"]\n",
    "aoai_key = os.getenv(\"AZURE_OPENAI_KEY\")\n",
    "azure_openai_embedding_deployment = os.getenv(\"AZURE_OPENAI_EMBEDDING_DEPLOYMENT\", \"text-embedding-3-large\")\n",
    "azure_openai_api_version = os.getenv(\"AZURE_OPENAI_API_VERSION\", \"2024-10-21\")\n",
    "azure_openai_chatgpt_deployment = os.getenv(\"AZURE_OPENAI_CHATGPT_DEPLOYMENT\", \"gpt-4o\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create embeddings\n",
    "Read your data, generate OpenAI embeddings and export to a format to insert your Azure AI Search index:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai import AzureOpenAI\n",
    "from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
    "import json\n",
    "\n",
    "openai_credential = DefaultAzureCredential()\n",
    "token_provider = get_bearer_token_provider(openai_credential, \"https://cognitiveservices.azure.com/.default\")\n",
    "\n",
    "client = AzureOpenAI(\n",
    "    api_version=azure_openai_api_version,\n",
    "    azure_endpoint=azure_openai_endpoint,\n",
    "    api_key=aoai_key,\n",
    "    azure_ad_token_provider=token_provider if not aoai_key else None\n",
    ")\n",
    "\n",
    "output_path = os.path.join('output', 'docVectors.json')\n",
    "\n",
    "if not os.path.exists(output_path):\n",
    "    # Generate Document Embeddings using OpenAI 3 large\n",
    "    # Read the text-sample.json\n",
    "    path = os.path.join('..', '..', '..', 'data', 'text-sample.json')\n",
    "    with open(path, 'r', encoding='utf-8') as file:\n",
    "        input_data = json.load(file)\n",
    "\n",
    "    titles = [item['title'] for item in input_data]\n",
    "    content = [item['content'] for item in input_data]\n",
    "    title_response = client.embeddings.create(input=titles, model=azure_openai_embedding_deployment, dimensions=1024)\n",
    "    title_embeddings = [item.embedding for item in title_response.data]\n",
    "    content_response = client.embeddings.create(input=content, model=azure_openai_embedding_deployment, dimensions=1024)\n",
    "    content_embeddings = [item.embedding for item in content_response.data]\n",
    "\n",
    "    # Generate embeddings for title and content fields\n",
    "    for i, item in enumerate(input_data):\n",
    "        title = item['title']\n",
    "        content = item['content']\n",
    "        item['titleVector'] = title_embeddings[i]\n",
    "        item['contentVector'] = content_embeddings[i]\n",
    "\n",
    "    # Output embeddings to docVectors.json file\n",
    "    output_directory = os.path.dirname(output_path)\n",
    "    if not os.path.exists(output_directory):\n",
    "        os.makedirs(output_directory)\n",
    "    with open(output_path, \"w\") as f:\n",
    "        json.dump(input_data, f)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create your search index\n",
    "\n",
    "Create your search index schema and vector search configuration. If you get an error, check the search service for available quota and check the .env file to make sure you're using a unique search index name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "my-demo-index created\n"
     ]
    }
   ],
   "source": [
    "from azure.search.documents.indexes import SearchIndexClient\n",
    "from azure.search.documents.indexes.models import (\n",
    "    SimpleField,\n",
    "    SearchFieldDataType,\n",
    "    SearchableField,\n",
    "    SearchField,\n",
    "    VectorSearch,\n",
    "    HnswAlgorithmConfiguration,\n",
    "    VectorSearchProfile,\n",
    "    SemanticConfiguration,\n",
    "    SemanticPrioritizedFields,\n",
    "    SemanticField,\n",
    "    SemanticSearch,\n",
    "    SearchIndex,\n",
    "    AzureOpenAIVectorizer,\n",
    "    AzureOpenAIVectorizerParameters\n",
    ")\n",
    "\n",
    "\n",
    "# Create a search index\n",
    "index_client = SearchIndexClient(\n",
    "    endpoint=endpoint, credential=credential)\n",
    "fields = [\n",
    "    SimpleField(name=\"id\", type=SearchFieldDataType.String, key=True, sortable=True, filterable=True, facetable=False),\n",
    "    SearchableField(name=\"title\", type=SearchFieldDataType.String),\n",
    "    SearchableField(name=\"content\", type=SearchFieldDataType.String),\n",
    "    SearchableField(name=\"category\", type=SearchFieldDataType.String,\n",
    "                    filterable=True),\n",
    "    SearchField(name=\"titleVector\", type=SearchFieldDataType.Collection(SearchFieldDataType.Single),\n",
    "                searchable=True, stored=False, vector_search_dimensions=1024, vector_search_profile_name=\"myHnswProfile\"),\n",
    "    SearchField(name=\"contentVector\", type=SearchFieldDataType.Collection(SearchFieldDataType.Single),\n",
    "                searchable=True, stored=False, vector_search_dimensions=1024, vector_search_profile_name=\"myHnswProfile\"),\n",
    "]\n",
    "\n",
    "# Configure the vector search configuration  \n",
    "vector_search = VectorSearch(\n",
    "    algorithms=[\n",
    "        HnswAlgorithmConfiguration(\n",
    "            name=\"myHnsw\"\n",
    "        )\n",
    "    ],\n",
    "    profiles=[\n",
    "        VectorSearchProfile(\n",
    "            name=\"myHnswProfile\",\n",
    "            algorithm_configuration_name=\"myHnsw\",\n",
    "            vectorizer_name=\"myVectorizer\"\n",
    "        )\n",
    "    ],\n",
    "    vectorizers=[\n",
    "        AzureOpenAIVectorizer(\n",
    "            vectorizer_name=\"myVectorizer\",\n",
    "            parameters=AzureOpenAIVectorizerParameters(\n",
    "                resource_url=azure_openai_endpoint,\n",
    "                deployment_name=azure_openai_embedding_deployment,\n",
    "                api_key=aoai_key,\n",
    "                model_name=azure_openai_embedding_deployment\n",
    "            )\n",
    "        )\n",
    "    ]\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "semantic_config = SemanticConfiguration(\n",
    "    name=\"my-semantic-config\",\n",
    "    prioritized_fields=SemanticPrioritizedFields(\n",
    "        content_fields=[SemanticField(field_name=\"content\")]\n",
    "    )\n",
    ")\n",
    "\n",
    "# Create the semantic settings with the configuration\n",
    "semantic_search = SemanticSearch(configurations=[semantic_config])\n",
    "\n",
    "# Create the search index with the semantic settings\n",
    "index = SearchIndex(name=index_name, fields=fields,\n",
    "                    vector_search=vector_search, semantic_search=semantic_search)\n",
    "result = index_client.create_or_update_index(index)\n",
    "print(f'{result.name} created')\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Insert text and embeddings into vector store\n",
    "Add texts and metadata from the JSON data to the vector store:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from azure.search.documents import SearchClient\n",
    "\n",
    "search_client = SearchClient(endpoint=endpoint, index_name=index_name, credential=credential)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Uploaded 108 documents\n"
     ]
    }
   ],
   "source": [
    "# Upload some documents to the index\n",
    "output_path = os.path.join('output', 'docVectors.json')\n",
    "output_directory = os.path.dirname(output_path)\n",
    "if not os.path.exists(output_directory):\n",
    "    os.makedirs(output_directory)\n",
    "with open(output_path, 'r') as file:  \n",
    "    documents = json.load(file)  \n",
    "search_client = SearchClient(endpoint=endpoint, index_name=index_name, credential=credential)\n",
    "result = search_client.upload_documents(documents)\n",
    "print(f\"Uploaded {len(documents)} documents\") "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Retrieve chunks using hybrid search\n",
    "\n",
    "Before evaluating the effects of query rewriting, it's useful to establish a baseline as to what hybrid search returns without any query rewriting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from azure.search.documents.models import VectorizableTextQuery\n",
    "\n",
    "def hybrid_search(search_client: SearchClient, query: str) -> pd.DataFrame:\n",
    "    results = search_client.search(\n",
    "        search_text=query,\n",
    "        vector_queries=[\n",
    "            # k_nearest_neighbors should be set to 50 in order to boost the relevance of hybrid search\n",
    "            # Increasing the vector recall set size from 1 to 50 in hybrid search benefits relevance by\n",
    "            # improving the diversity of vector query results that will be considered by RRF, ensuring a more comprehensive representation\n",
    "            # of the data results and more robustness to varying similarity scores or closely related similarity scores.\n",
    "            VectorizableTextQuery(text=query, k_nearest_neighbors=50, fields=\"contentVector\")\n",
    "        ],\n",
    "        top=3,\n",
    "        select=\"id, title, content\",\n",
    "        search_fields=[\"content\"]\n",
    "    )\n",
    "    data = [[result[\"id\"], result[\"title\"], result[\"content\"], result[\"@search.score\"]] for result in results]\n",
    "    return pd.DataFrame(data, columns=[\"id\", \"title\", \"content\", \"@search.score\"])\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell demonstrates the results of hybrid search using a sample query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>title</th>\n",
       "      <th>content</th>\n",
       "      <th>@search.score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4</td>\n",
       "      <td>Azure Storage</td>\n",
       "      <td>Azure Storage is a scalable, durable, and high...</td>\n",
       "      <td>0.033333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>36</td>\n",
       "      <td>Azure Data Lake Storage</td>\n",
       "      <td>Azure Data Lake Storage is a scalable, secure,...</td>\n",
       "      <td>0.032266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>52</td>\n",
       "      <td>Azure Table Storage</td>\n",
       "      <td>Azure Table Storage is a fully managed, NoSQL ...</td>\n",
       "      <td>0.031250</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                    title  \\\n",
       "0   4            Azure Storage   \n",
       "1  36  Azure Data Lake Storage   \n",
       "2  52      Azure Table Storage   \n",
       "\n",
       "                                             content  @search.score  \n",
       "0  Azure Storage is a scalable, durable, and high...       0.033333  \n",
       "1  Azure Data Lake Storage is a scalable, secure,...       0.032266  \n",
       "2  Azure Table Storage is a fully managed, NoSQL ...       0.031250  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hybrid_search(search_client, \"scalable storage solution\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use built-in query rewriting\n",
    "\n",
    "Search offers [query rewriting](https://learn.microsoft.com/azure/search/semantic-how-to-query-rewrite) built-in with usage of the [semantic ranker](https://learn.microsoft.com/azure/search/semantic-how-to-query-request). Evaluate this first before trying other solutions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>title</th>\n",
       "      <th>content</th>\n",
       "      <th>@search.score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>40</td>\n",
       "      <td>Azure Cognitive Search</td>\n",
       "      <td>Azure Cognitive Search is a fully managed sear...</td>\n",
       "      <td>0.033333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>Azure Cognitive Services</td>\n",
       "      <td>Azure Cognitive Services are a set of AI servi...</td>\n",
       "      <td>0.032002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>90</td>\n",
       "      <td>Azure Cognitive Services</td>\n",
       "      <td>Azure Cognitive Services is a collection of AI...</td>\n",
       "      <td>0.031545</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                     title  \\\n",
       "0  40    Azure Cognitive Search   \n",
       "1   3  Azure Cognitive Services   \n",
       "2  90  Azure Cognitive Services   \n",
       "\n",
       "                                             content  @search.score  \n",
       "0  Azure Cognitive Search is a fully managed sear...       0.033333  \n",
       "1  Azure Cognitive Services are a set of AI servi...       0.032002  \n",
       "2  Azure Cognitive Services is a collection of AI...       0.031545  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['search engine services', 'online search engine services', 'online search services']\n"
     ]
    }
   ],
   "source": [
    "from typing import Optional\n",
    "\n",
    "# Workaround to use query writes with debugging form the Python SDK\n",
    "import azure.search.documents._generated.models\n",
    "azure.search.documents._generated.models.SearchDocumentsResult._attribute_map[\"debug_info\"][\"key\"] = \"@search\\\\.debug\"\n",
    "from azure.search.documents._generated.models import DebugInfo\n",
    "import azure.search.documents._paging\n",
    "def get_debug_info(self) -> Optional[DebugInfo]:\n",
    "    self.continuation_token = None\n",
    "    return self._response.debug_info\n",
    "azure.search.documents._paging.SearchPageIterator.get_debug_info = azure.search.documents._paging._ensure_response(get_debug_info)\n",
    "azure.search.documents._paging.SearchItemPaged.get_debug_info = lambda self: self._first_iterator_instance().get_debug_info()\n",
    "\n",
    "search_client = SearchClient(endpoint=endpoint, index_name=index_name, credential=credential)\n",
    "\n",
    "results = search_client.search(\n",
    "    search_text=\"search service\",\n",
    "    # Issue a vector query for every single rewritten query\n",
    "    vector_queries=[VectorizableTextQuery(text=\"srch service\", k_nearest_neighbors=50, fields=\"contentVector\")],\n",
    "    query_type=\"semantic\",\n",
    "    semantic_configuration_name='my-semantic-config',\n",
    "    query_rewrites=\"generative|count-3\",\n",
    "    query_language=\"en\",\n",
    "    debug=\"queryRewrites\",\n",
    "    search_fields=[\"content\"],\n",
    "    top=3,\n",
    "    include_total_count=True\n",
    ")\n",
    "\n",
    "data = [[result[\"id\"], result[\"title\"], result[\"content\"], result[\"@search.score\"]] for result in results]\n",
    "df = pd.DataFrame(data, columns=[\"id\", \"title\", \"content\", \"@search.score\"])\n",
    "query_rewrites = results.get_debug_info().query_rewrites.text.rewrites\n",
    "\n",
    "display(df)\n",
    "print(query_rewrites)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Customize rewriting queries for improved relevance of results\n",
    "\n",
    "Users often use terse terms such as \"scalable storage solution\". These terms may match the contents of documents in the search index, but often an LLM can rewrite the query to improve the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import openai\n",
    "from pydantic import BaseModel\n",
    "\n",
    "class QueryRewrites(BaseModel):\n",
    "    queries: list[str]\n",
    "\n",
    "tools = [openai.pydantic_function_tool(QueryRewrites)]\n",
    "\n",
    "# This prompt can be customized to write the rewrites in a specific format or use specific words\n",
    "REWRITE_PROMPT = \"\"\"You are a helpful assistant. You help users search for the answers to their questions.\n",
    "You have access to Azure AI Search index with 100's of documents. Rewrite the following question into useful search queries to find the most relevant documents.\n",
    "The number of rewrites should be 3\n",
    "\"\"\"\n",
    "\n",
    "# If you are not using a supported model or region, you may not be able to use structured outputs\n",
    "# https://learn.microsoft.com/azure/ai-services/openai/how-to/structured-outputs\n",
    "def rewrite_query(openai_client: AzureOpenAI, query: str):\n",
    "    response = openai_client.chat.completions.create(\n",
    "        model=azure_openai_chatgpt_deployment,\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": REWRITE_PROMPT},\n",
    "            {\"role\": \"user\", \"content\": query}\n",
    "        ],\n",
    "        tools=tools\n",
    "    )\n",
    "    \n",
    "    # The JSON is always valid because the function tool is set to use strict=True\n",
    "    return json.loads(response.choices[0].message.tool_calls[0].function.arguments)[\"queries\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell demonstrates how an LLM can rewrite queries to improve their clarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Azure Search', 'Azure Search definition', 'Azure Search explanation']"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rewrite_query(client, \"what is azure sarch?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combining the rewritten queries manually using RRF\n",
    "\n",
    "Now that we can use a LLM to rewrite the query, we need to issue our queries and combine the results. We'll start by doing this manually to demonstrate how the RRF calculation works"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def query_rewrite_manual_rrf(search_client: SearchClient, openai_client: AzureOpenAI, query: str) -> pd.DataFrame:\n",
    "    rewritten_queries = rewrite_query(openai_client, query)\n",
    "    # pd.concat preserves the original index by default when concatenating tables\n",
    "    # This is important for the RRF calculation below\n",
    "    results = pd.concat([hybrid_search(search_client, rewritten_query) for rewritten_query in rewritten_queries], axis=0)\n",
    "    def rrf_score(row: pd.Series) -> float:\n",
    "        score = 0.0\n",
    "        k = 60\n",
    "        # rank = the original position in the results list the document was located at\n",
    "        for rank, df_row in results.iterrows():\n",
    "            # The RRF score is the sum of 1.0 / (k + document rank) in every result set the document shows up in\n",
    "            if df_row[\"id\"] == row[\"id\"]:\n",
    "                score += 1.0 / (k + rank)\n",
    "        return score\n",
    "    # Apply the RRF scoring function to every row in the data frame\n",
    "    results[\"rrf_score\"] = results.apply(rrf_score, axis=1)\n",
    "    # Return the deduplicated result set sorted by the most relevant RRF score\n",
    "    return rewritten_queries, results.drop_duplicates(subset=[\"id\"]).sort_values(by=\"rrf_score\", ascending=False)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell demonstrates how an unclear query (\"srch service\") is automatically rewritten and made more clear by an LLM. The resulting RRF score is higher for the most relevant document compared to the original search score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>title</th>\n",
       "      <th>content</th>\n",
       "      <th>@search.score</th>\n",
       "      <th>rrf_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>40</td>\n",
       "      <td>Azure Cognitive Search</td>\n",
       "      <td>Azure Cognitive Search is a fully managed sear...</td>\n",
       "      <td>0.033333</td>\n",
       "      <td>0.049727</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>90</td>\n",
       "      <td>Azure Cognitive Services</td>\n",
       "      <td>Azure Cognitive Services is a collection of AI...</td>\n",
       "      <td>0.032522</td>\n",
       "      <td>0.048652</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>Azure Cognitive Services</td>\n",
       "      <td>Azure Cognitive Services are a set of AI servi...</td>\n",
       "      <td>0.032522</td>\n",
       "      <td>0.033060</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>91</td>\n",
       "      <td>Azure Bot Service</td>\n",
       "      <td>Azure Bot Service is a managed, AI-powered ser...</td>\n",
       "      <td>0.031778</td>\n",
       "      <td>0.016129</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                     title  \\\n",
       "0  40    Azure Cognitive Search   \n",
       "2  90  Azure Cognitive Services   \n",
       "1   3  Azure Cognitive Services   \n",
       "2  91         Azure Bot Service   \n",
       "\n",
       "                                             content  @search.score  rrf_score  \n",
       "0  Azure Cognitive Search is a fully managed sear...       0.033333   0.049727  \n",
       "2  Azure Cognitive Services is a collection of AI...       0.032522   0.048652  \n",
       "1  Azure Cognitive Services are a set of AI servi...       0.032522   0.033060  \n",
       "2  Azure Bot Service is a managed, AI-powered ser...       0.031778   0.016129  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Azure Cognitive Search', 'Azure search service', 'Microsoft search service']\n"
     ]
    }
   ],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "rewritten_queries, results = query_rewrite_manual_rrf(search_client, client, \"srch service\")\n",
    "display(results)\n",
    "print(rewritten_queries)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combining the rewritten queries automatically using RRF\n",
    "\n",
    "We can use the built-in RRF instead of manually performing the RRF calculation ourselves. We will use query combination using boolean operators and multi-vector search to accomplish a similar goal. Please note that the RRF score will not be exactly the same as the manual calculation because the text index can be more efficiently queried using this approach and less-relevant documents are automatically filtered out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def query_rewrite_automatic_rrf(search_client: SearchClient, openai_client: AzureOpenAI, query: str) -> pd.DataFrame:\n",
    "    rewritten_queries = rewrite_query(openai_client, query)\n",
    "    # Quote the rewritten queries before joining them in the query syntax\n",
    "    formatted_queries = [f'\"{rewritten_query}\"' for rewritten_query in rewritten_queries]\n",
    "    # Use the OR operator to join rewritten queries together\n",
    "    # https://learn.microsoft.com/azure/search/query-lucene-syntax#bkmk_boolean\n",
    "    search_text = \" | \".join(formatted_queries)\n",
    "    results = search_client.search(\n",
    "        search_text=search_text,\n",
    "        # Issue a vector query for every single rewritten query\n",
    "        vector_queries=[VectorizableTextQuery(text=rewritten_query, k_nearest_neighbors=50, fields=\"contentVector\") for rewritten_query in rewritten_queries],\n",
    "        query_type=\"simple\",\n",
    "        # Any rewritten query from the joined query could match\n",
    "        search_mode=\"any\",\n",
    "        search_fields=[\"content\"],\n",
    "        top=3\n",
    "    )\n",
    "    # @search.score is equivalent to the manually computed RRF score above\n",
    "    data = [[result[\"id\"], result[\"title\"], result[\"content\"], result[\"@search.score\"]] for result in results]\n",
    "    return rewritten_queries, pd.DataFrame(data, columns=[\"id\", \"title\", \"content\", \"@search.score\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell demonstrates how the automatic approach has similar results to the manual one, even though the scores are not exactly equal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>title</th>\n",
       "      <th>content</th>\n",
       "      <th>@search.score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>40</td>\n",
       "      <td>Azure Cognitive Search</td>\n",
       "      <td>Azure Cognitive Search is a fully managed sear...</td>\n",
       "      <td>0.050000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>90</td>\n",
       "      <td>Azure Cognitive Services</td>\n",
       "      <td>Azure Cognitive Services is a collection of AI...</td>\n",
       "      <td>0.048916</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>Azure Cognitive Services</td>\n",
       "      <td>Azure Cognitive Services are a set of AI servi...</td>\n",
       "      <td>0.048652</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                     title  \\\n",
       "0  40    Azure Cognitive Search   \n",
       "1  90  Azure Cognitive Services   \n",
       "2   3  Azure Cognitive Services   \n",
       "\n",
       "                                             content  @search.score  \n",
       "0  Azure Cognitive Search is a fully managed sear...       0.050000  \n",
       "1  Azure Cognitive Services is a collection of AI...       0.048916  \n",
       "2  Azure Cognitive Services are a set of AI servi...       0.048652  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Azure Cognitive Search service', 'How to use search service in Azure', 'Implementing search service in Azure']\n"
     ]
    }
   ],
   "source": [
    "rewritten_queries, results = query_rewrite_automatic_rrf(search_client, client, \"srch service\")\n",
    "display(results)\n",
    "print(rewritten_queries)"
   ]
  },
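  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough guide to reading these scores, the following is a sketch of the usual RRF formula rather than a statement of the service's exact internals. The constant $k = 60$ and ranks counted from 0 are assumptions made here for illustration. Each ranked list $r$ (one text query plus one vector query per rewritten query) contributes a reciprocal term to a document $d$:\n",
    "\n",
    "$$\\mathrm{RRF}(d) = \\sum_{r \\in R} \\frac{1}{k + \\mathrm{rank}_r(d)}$$\n",
    "\n",
    "Under those assumptions, a document ranked first in three lists would score $3 \\times \\frac{1}{60} = 0.05$, which is consistent with the top `@search.score` in the output above."
   ]
  },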
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Continue to improve relevance using hybrid and semantic\n",
    "\n",
    "Once you are using the automatic RRF combination method, you can add semantic ranking to improve relevance further"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def query_rewrite_automatic_rrf_semantic(search_client: SearchClient, openai_client: AzureOpenAI, query: str) -> pd.DataFrame:\n",
    "    rewritten_queries = rewrite_query(openai_client, query)\n",
    "    # Quote the rewritten queries before joining them together using the query syntax\n",
    "    formatted_queries = [f'\"{rewritten_query}\"' for rewritten_query in rewritten_queries]\n",
    "    # Use the OR operator to join rewritten queries together\n",
    "    # https://learn.microsoft.com/azure/search/query-lucene-syntax#bkmk_boolean\n",
    "    search_text = \" | \".join(formatted_queries)\n",
    "    # The semantic ranker expects plain text queries with no search operators\n",
    "    semantic_query = \" \".join(rewritten_queries)\n",
    "    results = search_client.search(\n",
    "        search_text=search_text,\n",
    "        # Issue a vector query for every single rewritten query\n",
    "        vector_queries=[VectorizableTextQuery(text=rewritten_query, k_nearest_neighbors=50, fields=\"contentVector\") for rewritten_query in rewritten_queries],\n",
    "        # Any rewritten query from the joined query could match\n",
    "        search_mode=\"any\",\n",
    "        search_fields=[\"content\"],\n",
    "        query_type=\"simple\",\n",
    "        # Pass in the plain text concatenation of the rewritten queries for semantic ranking\n",
    "        semantic_query=semantic_query,\n",
    "        semantic_configuration_name='my-semantic-config',\n",
    "        top=3\n",
    "    )\n",
    "    # @search.score is equivalent to the manually computed RRF score above\n",
    "    # @search.rerankerscore is the semantic reranking of the combined results\n",
    "    data = [[result[\"id\"], result[\"title\"], result[\"content\"], result[\"@search.score\"], result[\"@search.reranker_score\"]] for result in results]\n",
    "    return rewritten_queries, pd.DataFrame(data, columns=[\"id\", \"title\", \"content\", \"@search.score\", \"@search.reranker_score\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell demonstrates how the semantic score compares to the RRF score. The semantic score ranges from 0-4, where a higher score indicates higher relvance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>title</th>\n",
       "      <th>content</th>\n",
       "      <th>@search.score</th>\n",
       "      <th>@search.reranker_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>40</td>\n",
       "      <td>Azure Cognitive Search</td>\n",
       "      <td>Azure Cognitive Search is a fully managed sear...</td>\n",
       "      <td>0.050000</td>\n",
       "      <td>2.302983</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>90</td>\n",
       "      <td>Azure Cognitive Services</td>\n",
       "      <td>Azure Cognitive Services is a collection of AI...</td>\n",
       "      <td>0.048395</td>\n",
       "      <td>2.043511</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>Azure Cognitive Services</td>\n",
       "      <td>Azure Cognitive Services are a set of AI servi...</td>\n",
       "      <td>0.047907</td>\n",
       "      <td>1.956755</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                     title  \\\n",
       "0  40    Azure Cognitive Search   \n",
       "1  90  Azure Cognitive Services   \n",
       "2   3  Azure Cognitive Services   \n",
       "\n",
       "                                             content  @search.score  \\\n",
       "0  Azure Cognitive Search is a fully managed sear...       0.050000   \n",
       "1  Azure Cognitive Services is a collection of AI...       0.048395   \n",
       "2  Azure Cognitive Services are a set of AI servi...       0.047907   \n",
       "\n",
       "   @search.reranker_score  \n",
       "0                2.302983  \n",
       "1                2.043511  \n",
       "2                1.956755  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['search service', 'Azure search service', 'How to use Azure search service']\n"
     ]
    }
   ],
   "source": [
    "rewritten_queries, results = query_rewrite_automatic_rrf_semantic(search_client, client, \"srch service\")\n",
    "display(results)\n",
    "print(rewritten_queries)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
